#!/usr/bin/env python3
"""
scatterplot_state.py

Scatterplot of CPU vs GPU mean wordcount per item for a chosen group (Warm/Cold/Alt),
with different markers per regime, slight transparency, and a 45-degree reference line.

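Defaults: group Warm; Warm tags W2/W3/W4; Alt tags A1/A2/A3; Cold tag C1.
Input logs are read from --base_dir and must be named
"<Condition>_<CPU|GPU>_T0_<tag>.txt"; the figure is written to <base_dir>/<out>
(an absolute --out path is used as-is).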

Usage:
  python scatterplot_state.py --base_dir "/path/to/logs" --group Warm --out cpu_gpu_scatter_warm.png
"""

import argparse
import re
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# Start of a numbered item: a line beginning (after optional whitespace, which
# may include blank lines) with a 1-2 digit number, a period, and one
# whitespace character. The number is captured so re.split() keeps it in the
# resulting list, interleaved with the item bodies.
ITEM_SPLIT_RE = re.compile(r"(?m)^\s*(\d{1,2})\.\s")

# Conditions to plot, in the order they appear in the legend, each mapped to
# the matplotlib marker used for its scatter series.
CONDITIONS = dict(
    BaselineNoJSON={"marker": "o"},
    BaselineWithJSON={"marker": "D"},
    HighRigor={"marker": "s"},
    PushbackStrong={"marker": "^"},
)

def trim_item50_noise(text: str) -> str:
    """Drop a trailing "Please note ..." section from the last item's text.

    The section is recognised only when a blank line (two or more line breaks,
    LF or CRLF) is immediately followed by the whole words "please note",
    matched case-insensitively. Everything from that blank line onward is cut.
    Trailing whitespace is removed from the result in every case.
    """
    hit = re.search(r"(?ims)\r?\n\r?\n+please note\b", text)
    kept = text if hit is None else text[:hit.start()]
    return kept.rstrip()

def extract_items(text: str):
    """Split a numbered response into its 50 items.

    Items are delimited by ``ITEM_SPLIT_RE``: a line that begins (after optional
    whitespace) with a one- or two-digit number, a period, and a whitespace
    character. Text before the first such line is discarded. Each item's text
    is stripped, and item 50 additionally has a trailing "Please note ..."
    section removed via ``trim_item50_noise``.

    Returns:
        dict mapping each item number 1..50 to its stripped text.

    Raises:
        ValueError: if an item number occurs more than once (for example a
            nested numbered list inside an item, which would otherwise silently
            overwrite an earlier item's text), or if the item numbers are not
            exactly 1..50. The message names the offending numbers.
    """
    splits = ITEM_SPLIT_RE.split(text)
    items = {}
    duplicates = set()
    # With one capture group, re.split yields [preamble, num, body, num, body, ...].
    for num_str, body in zip(splits[1::2], splits[2::2]):
        num = int(num_str)
        if num in items:
            duplicates.add(num)
        items[num] = body.strip()
    if duplicates:
        raise ValueError(f"Bad item keys: duplicate item numbers {sorted(duplicates)}")
    expected = set(range(1, 51))
    if items.keys() != expected:
        missing = sorted(expected - items.keys())
        unexpected = sorted(items.keys() - expected)
        raise ValueError(f"Bad item keys: missing {missing}, unexpected {unexpected}")
    items[50] = trim_item50_noise(items[50])
    return items

def mean_wordcounts(paths):
    """Average per-item word counts over several response files.

    Each file is decoded as UTF-8 (undecodable bytes are dropped) and split into
    items 1..50 with ``extract_items``. An item's word count is its number of
    whitespace-separated tokens.

    Returns:
        list of 50 means; entry ``k - 1`` is the mean word count of item ``k``
        across all files in ``paths``.
    """
    item_numbers = range(1, 51)
    counts_by_item = {k: [] for k in item_numbers}
    for path in paths:
        parsed = extract_items(Path(path).read_text(encoding="utf-8", errors="ignore"))
        for k in item_numbers:
            counts_by_item[k].append(len(parsed[k].split()))
    return [np.mean(counts_by_item[k]) for k in item_numbers]

def build_filename(condition, hardware, tag):
    """Return the log filename for one condition / hardware / tag combination.

    The middle ``T0`` segment is fixed, e.g. ``("HighRigor", "GPU", "W2")``
    gives ``"HighRigor_GPU_T0_W2.txt"``.
    """
    return "{}_{}_T0_{}.txt".format(condition, hardware, tag)

def _tags_for_group(args):
    """Return the list of log tags to aggregate for ``args.group``.

    Warm and Alt use their configured tag lists; Cold uses the single
    ``--cold_tag``.
    """
    if args.group == "Warm":
        return list(args.warm_tags)
    if args.group == "Alt":
        return list(args.alt_tags)
    return [args.cold_tag]


def main():
    """Parse CLI arguments, load the chosen group's logs, and plot CPU vs GPU.

    For every condition in ``CONDITIONS``, the per-item mean word counts across
    the group's tags are computed separately for the CPU and GPU logs. They are
    drawn as one marker series per condition. A dashed y = x reference line is
    added, and both axes share the same limits with equal aspect so that line
    is drawn at 45 degrees. The figure is saved to ``<base_dir>/<out>`` at
    300 dpi and then shown.

    Raises:
        FileNotFoundError: if any expected log file is missing; every missing
            path is listed in the message.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base_dir", required=True)
    ap.add_argument("--group", choices=["Warm", "Cold", "Alt"], default="Warm")
    ap.add_argument("--warm_tags", nargs="+", default=["W2", "W3", "W4"])
    ap.add_argument("--alt_tags", nargs="+", default=["A1", "A2", "A3"])
    ap.add_argument("--cold_tag", default="C1")
    ap.add_argument("--out", default="cpu_gpu_scatter.png")
    args = ap.parse_args()

    base = Path(args.base_dir)
    tags = _tags_for_group(args)

    # condition -> (CPU log paths, GPU log paths) for the selected group.
    file_sets = {
        cond: (
            [base / build_filename(cond, "CPU", t) for t in tags],
            [base / build_filename(cond, "GPU", t) for t in tags],
        )
        for cond in CONDITIONS
    }

    # Report every missing file at once instead of stopping at the first one.
    missing = [
        p
        for cpu_files, gpu_files in file_sets.values()
        for p in cpu_files + gpu_files
        if not p.exists()
    ]
    if missing:
        raise FileNotFoundError("Missing: " + ", ".join(str(p) for p in missing))

    data = {
        cond: (mean_wordcounts(cpu_files), mean_wordcounts(gpu_files))
        for cond, (cpu_files, gpu_files) in file_sets.items()
    }

    fig, ax = plt.subplots()
    for cond, (cpu_vals, gpu_vals) in data.items():
        ax.scatter(cpu_vals, gpu_vals, marker=CONDITIONS[cond]["marker"], alpha=0.6, label=cond)

    all_vals = np.concatenate([np.asarray(series) for pair in data.values() for series in pair])
    lo, hi = float(all_vals.min()), float(all_vals.max())
    # Pad so edge points are not clipped; use a fixed pad if all values coincide.
    pad = 0.05 * (hi - lo) if hi > lo else 1.0
    limits = (lo - pad, hi + pad)
    # Identity reference line, styled so it is not mistaken for a data series.
    ax.plot(limits, limits, linestyle="--", color="gray", linewidth=1, label="y = x")
    # Identical limits + equal aspect so the reference line is truly at 45 degrees.
    ax.set_xlim(limits)
    ax.set_ylim(limits)
    ax.set_aspect("equal", adjustable="box")

    ax.set_xlabel("CPU Mean Word Count per Item")
    ax.set_ylabel("GPU Mean Word Count per Item")
    ax.set_title(f"CPU vs GPU Structural Divergence ({args.group})")
    ax.legend()
    fig.tight_layout()
    fig.savefig(str(base / args.out), dpi=300)
    plt.show()

if __name__ == "__main__":
    # Build the plot only when run as a script, not when imported as a module.
    main()
